/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ut/src/runtime/kernel/opencl/common.h"
#include "nnacl/concat_parameter.h"

namespace mindspore::lite::opencl::test {

// GoogleTest fixture for the OpenCL Concat kernel tests below; it adds nothing of its
// own and relies entirely on whatever CommonTest provides.
class TestOpenCL_Concat : public CommonTest {
 public:
  TestOpenCL_Concat() = default;
};

namespace {
// Builds a ConcatParameter for PrimitiveType_Concat (populated like
// src/ops/populate/concat_populate.cc) with the requested concat axis, and hands it
// back as the generic OpParameter base pointer expected by TestMain.
// NOTE(review): ownership of the returned object presumably passes to TestMain /
// the kernel — confirm against common.h.
OpParameter *CreateParameter(int axis) {
  // Qualified call: selects the templated helper from the enclosing test namespace,
  // not this overload.
  auto *concat_param = test::CreateParameter<ConcatParameter>(schema::PrimitiveType_Concat);
  concat_param->axis_ = axis;
  OpParameter *base = reinterpret_cast<OpParameter *>(concat_param);
  return base;
}
}  // namespace

TEST_F(TestOpenCL_Concat, input2_axis0) {
  // Two 1x1x1x8 tensors stacked along axis 0 give a 2x1x1x8 result whose flat
  // layout is simply input 0 followed by input 1. Run in fp32 and fp16.
  const std::vector<int> in_shape = {1, 1, 1, 8};
  const std::vector<int> out_shape = {2, 1, 1, 8};
  constexpr int kAxis = 0;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.03, 0.37};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.47};
  float expected[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.03, 0.37, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.47};
  for (const bool use_fp16 : {false, true}) {
    // fp16 loses precision, so it gets a looser tolerance than fp32.
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{in_shape, in0, VAR}, {in_shape, in1, VAR}}, {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input2_axis0_shape1) {
  // Rank-1 edge case: two single-element tensors concatenated into a length-2 vector.
  // NOTE(review): only the fp32 path is exercised here; why fp16 is skipped for this
  // shape is not recorded in this file.
  const std::vector<int> in_shape = {1};
  const std::vector<int> out_shape = {2};
  constexpr int kAxis = 0;
  constexpr bool kUseFp16 = false;
  float in0[] = {0.75};
  float in1[] = {0.5};
  float expected[] = {0.75, 0.5};
  OpParameter *concat_param = CreateParameter(kAxis);
  // fp32 tolerance.
  TestMain({{in_shape, in0, VAR}, {in_shape, in1, VAR}}, {out_shape, expected}, concat_param, kUseFp16, 1e-9);
}

TEST_F(TestOpenCL_Concat, input2_axis1_Align) {
  // Two 2x2x2x8 tensors joined on axis 1 -> 2x4x2x8. Within each batch, the 32 values
  // of input 0 are followed by the 32 values of input 1.
  // ("Align" presumably means the innermost dimension is a multiple of 4 — unconfirmed.)
  const std::vector<int> in_shape = {2, 2, 2, 8};
  const std::vector<int> out_shape = {2, 4, 2, 8};
  constexpr int kAxis = 1;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41};
  float expected[] = {
    0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74,
    0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69,
    0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,
    0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30,
    0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25,
    0.39, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,
    0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41};
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{in_shape, in0, VAR}, {in_shape, in1, VAR}}, {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input6_axis1_Align) {
  // Six identically shaped 2x3x2x8 tensors joined on axis 1 -> 2x18x2x8.
  // Inputs 0/2/4 share one row pattern and inputs 1/3/5 share another, so within each
  // batch the expected output alternates 48-value blocks (3*2*8) of the two patterns.
  const std::vector<int> in_shape = {2, 3, 2, 8};
  const std::vector<int> out_shape = {2, 18, 2, 8};
  constexpr int kAxis = 1;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39};

  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41};

  float in2[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39};

  float in3[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41};

  float in4[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39};

  float in5[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
                 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41};
  float expected[] = {
    0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74,
    0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59,
    0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,
    0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23,
    0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13,
    0.41, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06,
    0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,
    0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
    0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74,
    0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69,
    0.13, 0.41, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75,
    0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30,
    0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13,
    0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,
    0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46,
    0.69, 0.13, 0.41, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39,
    0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74,
    0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69,
    0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,
    0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23,
    0.46, 0.69, 0.13, 0.41, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25,
    0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06,
    0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23, 0.46,
    0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41,
    0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74,
    0.23, 0.46, 0.69, 0.13, 0.41, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59,
    0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75,
    0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.25, 0.39, 0.5,  0.6,  0.74, 0.23,
    0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13,
    0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.5,  0.6,
    0.74, 0.23, 0.46, 0.69, 0.13, 0.41};

  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{in_shape, in0, VAR},
              {in_shape, in1, VAR},
              {in_shape, in2, VAR},
              {in_shape, in3, VAR},
              {in_shape, in4, VAR},
              {in_shape, in5, VAR}},
             {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input6_axis2_Align) {
  // Rank-3 case: six 1x1x8 tensors joined on their last axis (2) -> 1x1x48, i.e. the
  // six inputs laid end to end.
  const std::vector<int> in_shape = {1, 1, 8};
  const std::vector<int> out_shape = {1, 1, 48};
  constexpr int kAxis = 2;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.13, 0.16};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.47, 0.16};
  float in2[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.16};
  float in3[] = {0.52, 0.63, 0.78, 0.43, 0.56, 0.69, 0.87, 0.16};
  float in4[] = {0.5, 0.6, 0.74, 0.30, 0.9, 0.59, 0.13, 0.16};
  float in5[] = {0.75, 0.06, 0.74, 0.23, 0.46, 0.69, 0.47, 0.16};
  float expected[] = {0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.13, 0.16, 0.5,  0.6,  0.74, 0.23,
                      0.46, 0.69, 0.47, 0.16, 0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.16,
                      0.52, 0.63, 0.78, 0.43, 0.56, 0.69, 0.87, 0.16, 0.5,  0.6,  0.74, 0.30,
                      0.9,  0.59, 0.13, 0.16, 0.75, 0.06, 0.74, 0.23, 0.46, 0.69, 0.47, 0.16};
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{in_shape, in0, VAR},
              {in_shape, in1, VAR},
              {in_shape, in2, VAR},
              {in_shape, in3, VAR},
              {in_shape, in4, VAR},
              {in_shape, in5, VAR}},
             {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input2_axis3_UnAlign) {
  // Last-axis concat of 2x2x2x8 and 2x2x2x9 -> 2x2x2x17: every output row is an 8-value
  // row of input 0 followed by a 9-value row of input 1.
  const std::vector<int> lhs_shape = {2, 2, 2, 8};
  const std::vector<int> rhs_shape = {2, 2, 2, 9};
  const std::vector<int> out_shape = {2, 2, 2, 17};
  constexpr int kAxis = 3;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39,
                 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39};
  float in1[] = {0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69,
                 0.13, 0.41, 0.52, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52, 0.5,  0.6,  0.74,
                 0.23, 0.46, 0.69, 0.13, 0.41, 0.52, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
                 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69,
                 0.13, 0.41, 0.52, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52};
  float expected[] = {
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
    0.75, 0.06, 0.74, 0.30, 0.9, 0.59, 0.25, 0.39, 0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.13, 0.41, 0.52,
  };
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{lhs_shape, in0, VAR}, {rhs_shape, in1, VAR}}, {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input3_axis1_UnAlign) {
  // Rank-2 case: 1x6, 1x7 and 1x8 joined on axis 1 -> 1x21 (inputs laid end to end).
  const std::vector<int> shape_a = {1, 6};
  const std::vector<int> shape_b = {1, 7};
  const std::vector<int> shape_c = {1, 8};
  const std::vector<int> out_shape = {1, 21};
  constexpr int kAxis = 1;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.47};
  float in2[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13};
  float expected[] = {0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.5,  0.6,  0.74, 0.23, 0.46,
                      0.69, 0.47, 0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13};
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{shape_a, in0, VAR}, {shape_b, in1, VAR}, {shape_c, in2, VAR}}, {out_shape, expected}, concat_param,
             use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input4_axis3_UnAlign) {
  // Four rank-4 inputs with 6/7/8/9 values on the last axis -> 1x1x1x30.
  // The axis is given as -1 rather than 3, so this case also covers negative-axis
  // handling; the expected output shape shows -1 selects the last axis here.
  const std::vector<int> shape_a = {1, 1, 1, 6};
  const std::vector<int> shape_b = {1, 1, 1, 7};
  const std::vector<int> shape_c = {1, 1, 1, 8};
  const std::vector<int> shape_d = {1, 1, 1, 9};
  const std::vector<int> out_shape = {1, 1, 1, 30};
  constexpr int kAxis = -1;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.47};
  float in2[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13};
  float in3[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.26};
  float expected[] = {0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.47, 0.03, 0.37,
                      0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.26};
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{shape_a, in0, VAR}, {shape_b, in1, VAR}, {shape_c, in2, VAR}, {shape_d, in3, VAR}},
             {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input5_axis3_UnAlign) {
  // Five rank-4 inputs with 6..10 values on axis 3 -> 1x1x1x40 (inputs laid end to end).
  const std::vector<int> shape_a = {1, 1, 1, 6};
  const std::vector<int> shape_b = {1, 1, 1, 7};
  const std::vector<int> shape_c = {1, 1, 1, 8};
  const std::vector<int> shape_d = {1, 1, 1, 9};
  const std::vector<int> shape_e = {1, 1, 1, 10};
  const std::vector<int> out_shape = {1, 1, 1, 40};
  constexpr int kAxis = 3;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.47};
  float in2[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13};
  float in3[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.26};
  float in4[] = {0.06, 0.47, 0.74, 0.23, 0.56, 0.69, 0.73, 0.13, 0.96, 0.78};
  float expected[] = {0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.47, 0.03,
                      0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13,
                      0.13, 0.26, 0.06, 0.47, 0.74, 0.23, 0.56, 0.69, 0.73, 0.13, 0.96, 0.78};
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{shape_a, in0, VAR},
              {shape_b, in1, VAR},
              {shape_c, in2, VAR},
              {shape_d, in3, VAR},
              {shape_e, in4, VAR}},
             {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

TEST_F(TestOpenCL_Concat, input6_axis3_UnAlign) {
  // Six rank-4 inputs with 6..11 values on axis 3 -> 1x1x1x51 (inputs laid end to end).
  const std::vector<int> shape_a = {1, 1, 1, 6};
  const std::vector<int> shape_b = {1, 1, 1, 7};
  const std::vector<int> shape_c = {1, 1, 1, 8};
  const std::vector<int> shape_d = {1, 1, 1, 9};
  const std::vector<int> shape_e = {1, 1, 1, 10};
  const std::vector<int> shape_f = {1, 1, 1, 11};
  const std::vector<int> out_shape = {1, 1, 1, 51};
  constexpr int kAxis = 3;
  float in0[] = {0.75, 0.06, 0.74, 0.30, 0.9, 0.59};
  float in1[] = {0.5, 0.6, 0.74, 0.23, 0.46, 0.69, 0.47};
  float in2[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13};
  float in3[] = {0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.26};
  float in4[] = {0.06, 0.47, 0.74, 0.23, 0.56, 0.69, 0.73, 0.13, 0.96, 0.78};
  float in5[] = {0.16, 0.77, 0.84, 0.53, 0.36, 0.29, 0.53, 0.23, 0.86, 0.48, 0.36};
  float expected[] = {0.75, 0.06, 0.74, 0.30, 0.9,  0.59, 0.5,  0.6,  0.74, 0.23, 0.46, 0.69, 0.47,
                      0.03, 0.37, 0.74, 0.23, 0.46, 0.69, 0.13, 0.13, 0.03, 0.37, 0.74, 0.23, 0.46,
                      0.69, 0.13, 0.13, 0.26, 0.06, 0.47, 0.74, 0.23, 0.56, 0.69, 0.73, 0.13, 0.96,
                      0.78, 0.16, 0.77, 0.84, 0.53, 0.36, 0.29, 0.53, 0.23, 0.86, 0.48, 0.36};
  for (const bool use_fp16 : {false, true}) {
    const auto atol = use_fp16 ? 1e-3 : 1e-9;
    OpParameter *concat_param = CreateParameter(kAxis);
    TestMain({{shape_a, in0, VAR},
              {shape_b, in1, VAR},
              {shape_c, in2, VAR},
              {shape_d, in3, VAR},
              {shape_e, in4, VAR},
              {shape_f, in5, VAR}},
             {out_shape, expected}, concat_param, use_fp16, atol);
  }
}

}  // namespace mindspore::lite::opencl::test
